package io.leavesfly.tinyai.nnet.v2.init;

import io.leavesfly.tinyai.ndarr.NdArray;
import java.util.Random;

/**
 * 正态分布初始化器
 * <p>
 * 从正态分布 N(mean, std²) 中采样初始化张量
 *
 * @author leavesfly
 * @version 2.0
 */
public class NormalInitializer implements Initializer {

    private final float mean;
    private final float std;
    private final Random random;

    /**
     * 构造函数
     *
     * @param mean 均值
     * @param std  标准差
     */
    public NormalInitializer(float mean, float std) {
        this.mean = mean;
        this.std = std;
        this.random = new Random();
    }

    @Override
    public void initialize(NdArray tensor) {
        float[] data = tensor.getArray();
        for (int i = 0; i < data.length; i++) {
            data[i] = mean + (float) (random.nextGaussian() * std);
        }
    }
}
